#include <ros/ros.h>
#include <ros/package.h>
#include <chrono>
#include <geometry_msgs/Pose.h>
#include <geometry_msgs/PointStamped.h>
#include <std_msgs/Int32MultiArray.h>
#include <std_msgs/Float32MultiArray.h>
#define USE_OPENCV
#define DETECT_MARKER_X_NET
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include "caffe/caffe.hpp"
#include "caffe/util/spire_camera_reader.hpp"
#include "caffe/util/spire_video_writer.hpp"

#include "caffe/util/spire_cn_detector.h"
#include "DAS_Detect.hpp"
#include "findx.h"

#include <algorithm>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <time.h>


using namespace caffe;  // NOLINT(build/namespaces)
using namespace cv;


// Working resolution: every camera frame is resized to this before any
// detector runs, and detection coordinates are published in this frame.
constexpr int PROCESS_WIDTH = 640;
constexpr int PROCESS_HEIGHT = 480;

// Compiled-in defaults for the private ROS parameters of the same name.
// main() overwrites video_saver_dir, model_file, trained_file and label_file
// when the parameter is set; mean_file has no parameter and stays empty.
std::string video_saver_dir = "/home/nvidia";
std::string model_file = "/home/nvidia/catkin_ws/src/prometheus_detection_circlex/spire_caffe/examples/spire_x_classification/xnet_deploy.prototxt";
std::string trained_file = "/home/nvidia/catkin_ws/src/prometheus_detection_circlex/spire_caffe/examples/spire_x_classification/xnet_iter_10000.caffemodel";
std::string mean_file = "";
std::string label_file = "/home/nvidia/catkin_ws/src/prometheus_detection_circlex/spire_caffe/examples/spire_x_classification/label.txt";


// Sub-mode for the combined pipelines ('7' and '8' in main): the values 1, 2
// and 3 select which detector runs on the current frame. Written by
// modeSelectionCallback().
int mode_select = 2;

/**
 * Copies the first element of a mode-selection message into mode_select.
 *
 * @param msgs incoming message; only data[0] is read.
 *
 * An empty `data` array is ignored (mode_select keeps its previous value)
 * instead of indexing past the end of the vector.
 */
void modeSelectionCallback(const std_msgs::Int32MultiArray &msgs)
{
    if (msgs.data.empty()) {
        // Malformed message: keep the current mode rather than read data[0].
        return;
    }
    mode_select = msgs.data[0];
}

// Output publishers; main() advertises them on
// /prometheus/object_detection/circlex_det and /uav/vision/fly_detecter_drone.
ros::Publisher position_pub;
ros::Publisher bit_pub;

// Mode-selection subscriber; the subscribe() call in main() is commented out,
// so this handle is currently never connected.
ros::Subscriber mode_sub;

// Ellipse (circle) detector, configured by cx_init() and used by cx_detect().
spire::CNEllipseDetector cned;

// Disabled ArUco state, kept for reference:
//   MarkerDetector MDetector;
//   CameraParameters TheCameraParameters;

double _time_record;
void _tic() {
  _time_record = (double)cv::getTickCount();
}
double _toc() {
  double time_gap = ((double)cv::getTickCount()-_time_record)*1000. / cv::getTickFrequency();
  return time_gap;
  // std::cout << "Cost Time: " << time_gap << " ms" << std::endl;
}

void cx_init();
void cx_detect(Mat& srcIm, Mat3b& resultIm, vector<cv::Point>& pts, vector<float>& axis_bs);
void aruco_init(Mat& srcIm);
void aruco_detect(Mat& srcIm, Mat3b& resultIm, vector<cv::Point>& pts);
void flow01(Mat& srcIm, Mat3b& resultIm, vector<cv::Point>& pts);

int main(int argc, char** argv) {
  ros::init(argc, argv, "fly_detector");
  ros::NodeHandle nh("~");


    if (nh.getParam("video_saver_dir", video_saver_dir)) {
        ROS_INFO("video_saver_dir is %s", video_saver_dir.c_str());
    } else {
        ROS_WARN("didn't find parameter video_saver_dir");
    }

    if (nh.getParam("model_file", model_file)) {
        ROS_INFO("model_file is %s", model_file.c_str());
    } else {
        ROS_WARN("didn't find parameter model_file");
    }

    if (nh.getParam("trained_file", trained_file)) {
        ROS_INFO("trained_file is %s", trained_file.c_str());
    } else {
        ROS_WARN("didn't find parameter trained_file");
    }

    if (nh.getParam("label_file", label_file)) {
        ROS_INFO("label_file is %s", label_file.c_str());
    } else {
        ROS_WARN("didn't find parameter label_file");
    }

  ros::Rate loopRate(50);
  position_pub = nh.advertise<std_msgs::Float32MultiArray>("/prometheus/object_detection/circlex_det", 10);
  bit_pub = nh.advertise<std_msgs::Float32MultiArray>("/uav/vision/fly_detecter_drone", 10);
  // mode_sub = nh.subscribe("/control/mode_selection", 10, modeSelectionCallback);

  // std::cout << "please select camera you use(WEBCAM:1/FLYCAP:2): ";
  char cam_id = '1'; 
  int cam_id_real(0);
  // std::cin >> cam_id;
  if (cam_id == '2') {
  	cam_id_real = -1;
  }

  /****
  std::cout << "please select main algorithm loop:\n" <<
    "1. video recorder\n" <<
    "2. circle detector\n" <<
    "3. aruco detector\n" <<
    "4. 50 meter optical flow\n" <<
    "5. 1 meter optical flow\n" <<
    "6. x detector\n" <<
    "7. 50 to 0 (flow_50 circle flow_1) together\n" <<
    "8. 50 to 0 (flow_50 circle x_detect) together\n" <<
    "0. exit\n" << 
    "your choose: ";
  ****/
  char mode = '2';
  // std::cin >> mode;
  if (mode < '1' || mode > '8') return 0;

  SpireVideoWriter writer;
  writer.SetUp(video_saver_dir, 25.0, Size(PROCESS_WIDTH, PROCESS_HEIGHT));
  // SpireCameraReader layer;
  // layer.SetUp(cam_id_real);
  cv::VideoCapture cap(0);
  cx_init();
  
  DASDetect flow50;
  flow50.init();

  Mat image; Mat3b resultIm;

  int CenterX = PROCESS_WIDTH / 2, CenterY = PROCESS_HEIGHT / 2;
  double totTime(0);
  int countsDt(0), frameCount(0), detectionFps(0);
  bool drawReferenceCenter = true;
  vector<cv::Point> pointsDt;
  vector<float> axisbDt;


  cv::Point xp;
  Rect _rect;
  Point position;

  cap>>image;
  aruco_init(image);
  while(ros::ok()) {
    auto start = std::chrono::system_clock::now();

    _tic();
    // cout<<1<<endl;
    cap >> image;
    // cout<<2<<endl;
    // layer.GetOneFrame(image);
    resize(image, image, Size(PROCESS_WIDTH, PROCESS_HEIGHT));
    pointsDt.clear();
    axisbDt.clear();

    switch(mode) {
      case '1':
        // std::cout << "video recorder" << std::endl;
      // cout<<3<<endl;
        drawReferenceCenter = false;
        image.copyTo(resultIm); 
        break;
      case '2':
        // std::cout << "circle detector" << std::endl;
        cx_detect(image, resultIm, pointsDt, axisbDt);
        break;
      case '3':
        // std::cout << "aruco detector" << std::endl;
        aruco_detect(image, resultIm, pointsDt);
        break;
      case '4':
        // std::cout << "50 meter optical flow" << std::endl;
        frameCount++;
        _rect = flow50.detectMain(image,resultIm,frameCount);
        position.x = _rect.x + _rect.width/2;
        position.y = _rect.y + _rect.height/2;
        if((position.x+position.y) != 0)
        {
        	pointsDt.push_back(position);
        }
        break;
      case '5':
        // std::cout << "1 meter optical flow" << std::endl;
        flow01(image, resultIm, pointsDt);
        break;
      case '6':
        image.copyTo(resultIm);
        if (findX(image,xp)) {
          pointsDt.push_back(xp);
          circle(resultIm,xp,4,Scalar(0,0,255),2);
        }
        break;
      case '7':
      	if (mode_select == 3) {
 		  frameCount++;
          _rect = flow50.detectMain(image,resultIm,frameCount);
          position.x = _rect.x + _rect.width/2;
          position.y = _rect.y + _rect.height/2;
          if((position.x+position.y) != 0)
          {
        	pointsDt.push_back(position);
          }
      	} else if(mode_select == 2) {
      	  cx_detect(image, resultIm, pointsDt, axisbDt);    
      	} else if(mode_select == 1) {
      	  flow01(image, resultIm, pointsDt);
      	}
      	break;
      case '8':
        /*if (mode_select == 3) {
 		  frameCount++;
          _rect = flow50.detectMain(image,resultIm,frameCount);
          position.x = _rect.x + _rect.width/2;
          position.y = _rect.y + _rect.height/2;
          if((position.x+position.y) != 0)
           {
        	pointsDt.push_back(position);
           }
      	} else */
        if(mode_select == 2) {
      	  cx_detect(image, resultIm, pointsDt, axisbDt);
      	} else if(mode_select == 1) {
      	  image.copyTo(resultIm);
          if (findX(image,xp)) {
            pointsDt.push_back(xp);
            circle(resultIm,xp,4,Scalar(0,0,255),2);
          }
      	}
        break;
      default:
        break;
    }

    // cx_detect(image, resultIm, pointsDt, axisbDt);
    std_msgs::Float32MultiArray msg;
    std_msgs::Float32MultiArray bit_msg;
    static float qx = 67.0 / 57.3;
    static float qy = 37.7 / 57.3;
    static float obj_radius = 0.55;
    static float focal_len = 484;

    auto end = std::chrono::system_clock::now();

    if (pointsDt.size() > 0) {
      msg.data.push_back(1);
      msg.data.push_back(pointsDt[0].x);
      msg.data.push_back(pointsDt[0].y);
      msg.data.push_back(PROCESS_WIDTH);
      msg.data.push_back(PROCESS_HEIGHT);

      bit_msg.data.push_back( (pointsDt[0].x - PROCESS_WIDTH/2) / (float)PROCESS_WIDTH * qx );
      bit_msg.data.push_back( (pointsDt[0].y - PROCESS_HEIGHT/2) / (float)PROCESS_HEIGHT * qy );
      if (axisbDt.size() > 0)
        bit_msg.data.push_back( obj_radius * focal_len / axisbDt[0] );
      else
        bit_msg.data.push_back( 0 );
      bit_msg.data.push_back( std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() / 1000. );
      bit_msg.data.push_back( 1. );
      if (axisbDt.size() > 0)
        bit_msg.data.push_back( axisbDt[0] / PROCESS_WIDTH );
      else
        bit_msg.data.push_back( 0 );

      countsDt++;
    } else {
      msg.data.push_back(0);
      msg.data.push_back(0);
      msg.data.push_back(0);
      msg.data.push_back(0);
      msg.data.push_back(0);

      bit_msg.data.push_back( 0. );
      bit_msg.data.push_back( 0. );
      bit_msg.data.push_back( 0. );
      bit_msg.data.push_back( std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() / 1000. );
      bit_msg.data.push_back( 0. );
      bit_msg.data.push_back( 0. );
    }
    position_pub.publish(msg); 
    bit_pub.publish(bit_msg);
    ros::spinOnce();


    // cout<<4<<endl;
    

    // draw reference center
    if (drawReferenceCenter) {
    cv::line(resultIm, cv::Point(CenterX-40, CenterY), cv::Point(CenterX+40, CenterY), cv::Scalar(0,0,255));
    cv::line(resultIm, cv::Point(CenterX, CenterY-40), cv::Point(CenterX, CenterY+40), cv::Scalar(0,0,255));
    circle(resultIm, cv::Point(CenterX, CenterY), 40, cv::Scalar(0,0,255), 2);
    }
    
    if (drawReferenceCenter) {
    char dfps[128], buf[128];
    // Get time now for realtime video
    tm* local;
    time_t t = time(NULL);
    local = localtime(&t);
    strftime(buf, 64, "%Y-%m-%d %H:%M:%S", local);
    sprintf(dfps, "  DFPS:%d", detectionFps);
    cv::putText(resultIm, std::string(buf)+std::string(dfps), cv::Point(PROCESS_WIDTH*.05, PROCESS_HEIGHT*.05),
      2, 1.f, cv::Scalar(255, 255, 255));

    writer.PutOneFrame(Mat(resultIm));
    imshow("cam", resultIm);
    } else {
    	// cout<<5<<endl;
      writer.PutOneFrame(image);
      // cout<<6<<endl;
      cv::imshow("cam", image);
//cout<<7<<endl;
    }


    if(waitKey(1) == 27)
      break;
    // writer.PutOneFrame(Mat(resultIm));
    // writer.PutOneFrame(image);
//cout<<8<<endl;
    totTime += _toc();
    if (totTime >= 1000.0) {
      detectionFps = countsDt;
      std::cout << "DETECTION FPS:" << countsDt << std::endl;
      countsDt = 0; totTime = .0;
    }

loopRate.sleep();

  }

  return 0;
}


/**
 * Runs the global ellipse detector `cned` on one BGR frame.
 *
 * @param srcIm     input frame, converted with COLOR_BGR2GRAY before detection
 * @param resultIm  set to a clone of srcIm with every detected ellipse drawn on it
 * @param pts       if any ellipse is found, receives the centre of the first one
 *                  (_xc/_yc converted to int)
 * @param axis_bs   if any ellipse is found, receives the `_b` axis of the first one
 *
 * Nothing is appended when no ellipse is detected.
 */
void cx_detect(Mat& srcIm, Mat3b& resultIm, vector<cv::Point>& pts, vector<float>& axis_bs) {
  vector<spire::Ellipse> found;
  Mat1b grayFrame;
  cvtColor(srcIm, grayFrame, COLOR_BGR2GRAY);
  cned.Detect(grayFrame, found);

  resultIm = srcIm.clone();
  cned.DrawDetectedEllipses(resultIm, found);

  if (found.empty()) {
    return;
  }

  const spire::Ellipse& first = found[0];
  Point center;
  center.x = first._xc;
  center.y = first._yc;
  pts.push_back(center);
  axis_bs.push_back(first._b);
}

void cx_init() {
  Size sz(PROCESS_WIDTH, PROCESS_HEIGHT);
  // Parameters Settings
  int    iThrLength = 14;
  float  fThrObb = 3.0f; // Discarded..
  float  fThrPos = 1.0f;

  float  fThrMinScore = 0.3f;
  float  fMinReliability = 0.7f;
  int    iNs = 16;
  float  fMaxCenterDistance = sqrt(float(sz.width*sz.width + sz.height*sz.height)) * 0.05f;
  Size   szPreProcessingGaussKernelSize = Size(5,5);
  double dPreProcessingGaussSigma = 1.0;
  float  fDistanceToEllipseContour = 0.1f;
  cned.SetParameters(szPreProcessingGaussKernelSize,
                     dPreProcessingGaussSigma,
                     fThrPos,
                     fMaxCenterDistance,
                     iThrLength,
                     fThrObb,
                     fDistanceToEllipseContour,
                     fThrMinScore,
                     fMinReliability,
                     iNs);
#ifdef DETECT_MARKER_X_NET
  cned.cvmat_classifier_.Init(model_file, trained_file, mean_file, label_file);
#endif
}

// Marker size for the (disabled) ArUco pipeline. The disabled drawing code
// only renders 3D axes/cubes when this is > 0.
float TheMarkerSize;

/**
 * Initialises ArUco detection state. With the detector itself disabled, this
 * only resets TheMarkerSize to -1.
 *
 * @param srcIm a frame from the camera; unused now. The disabled code used it
 *              to resize the camera parameters to the frame size.
 *
 * Disabled setup, kept for reference:
 *   TheCameraParameters.readFromXMLFile(<calibration xml>);
 *   MDetector.setDictionary("ARUCO");
 *   MDetector.setThresholdParams(7, 7);
 *   MDetector.setThresholdParamRange(2, 0);
 *   if (TheCameraParameters.isValid()) TheCameraParameters.resize(srcIm.size());
 */
void aruco_init(Mat& srcIm) {
  (void)srcIm;
  TheMarkerSize = -1;
}

/**
 * Placeholder for ArUco marker detection; the real detector is disabled.
 *
 * Prints two separator lines to stdout and copies srcIm into resultIm.
 * `pts` is never written, so this mode never reports a detection.
 *
 * Disabled pipeline, kept for reference:
 *   TheMarkers = MDetector.detect(srcIm, TheCameraParameters, TheMarkerSize);
 *   srcIm.copyTo(resultIm);
 *   if (!TheMarkers.empty()) {
 *     pts.push_back(TheMarkers[0].getCenter());
 *     TheMarkers[0].draw(resultIm, Scalar(0, 0, 255));
 *   }
 *   if (TheCameraParameters.isValid() && TheMarkerSize > 0)
 *     for each marker: CvDrawingUtils::draw3dCube / draw3dAxis(resultIm, marker, TheCameraParameters);
 */
void aruco_detect(Mat& srcIm, Mat3b& resultIm, vector<cv::Point>& pts) {
  (void)pts;
  std::cout << "===================" << std::endl;
  srcIm.copyTo(resultIm);
  std::cout << "********************" << std::endl;
}

// Frame buffers for flow01(): `gray` holds the current grayscale frame and
// `prevGray` the previous one. `flow` and `motion2color` are not referenced
// in this file.
Mat gray, prevGray, flow, motion2color;

/**
 * Estimates the dominant image motion between the previous and the current
 * frame with sparse pyramidal Lucas-Kanade optical flow.
 *
 * Steps:
 * 1. Detect corners in the previous frame and track them into the current one.
 * 2. Average the tracked corner positions in each frame (coordinates are
 *    truncated to int before summing).
 * 3. Multiply the difference of the two averages by 3; any axis whose scaled
 *    value lies strictly between -10 and 10 is set to 0.
 * 4. Draw the result as an arrow from the image centre and report its tip.
 *
 * @param srcIm     current frame; converted with COLOR_BGR2GRAY
 * @param resultIm  receives a copy of srcIm with per-corner tracks (green) and
 *                  the mean-motion arrow (red, yellow tip) drawn on it
 * @param pts       receives one point, the image centre plus the scaled mean
 *                  motion. Nothing is appended on the first call (no previous
 *                  frame) or when no corners are found in the previous frame.
 *
 * The current frame always becomes the reference for the next call.
 */
void flow01(Mat& srcIm, Mat3b& resultIm, vector<cv::Point>& pts) {
  cvtColor(srcIm, gray, COLOR_BGR2GRAY);
  srcIm.copyTo(resultIm);

  if (!prevGray.empty()) {
    vector<Point2f> prevPoint;
    goodFeaturesToTrack(prevGray, prevPoint, 500, 0.001, 10, Mat(), 3, false, 0.04);

    // Previously this case `return`ed immediately, which also skipped the
    // prevGray update at the end: a featureless reference frame then stayed
    // the reference forever and flow01 never produced output again.
    if (!prevPoint.empty()) {
      cornerSubPix(prevGray, prevPoint, Size(10, 10), Size(-1, -1),
                   TermCriteria(TermCriteria::COUNT | TermCriteria::EPS, 20, 0.03));

      vector<Point2f> currPoint;
      vector<uchar> state;
      vector<float> err;
      calcOpticalFlowPyrLK(prevGray, gray, prevPoint, currPoint, state, err, Size(31, 31), 3);

      int opticalPrevCenter[2] = {0, 0};
      int opticalCurrCenter[2] = {0, 0};
      int pointNum = 0;
      for (size_t i = 0; i < state.size(); ++i) {
        if (state[i] == 0) continue;  // corner lost by the tracker
        const Point2f& p0 = prevPoint[i];
        const Point2f& p1 = currPoint[i];
        line(resultIm,
             Point(static_cast<int>(p0.x), static_cast<int>(p0.y)),
             Point(static_cast<int>(p1.x), static_cast<int>(p1.y)),
             Scalar(0, 255, 0));
        opticalPrevCenter[0] += static_cast<int>(p0.x);
        opticalPrevCenter[1] += static_cast<int>(p0.y);
        opticalCurrCenter[0] += static_cast<int>(p1.x);
        opticalCurrCenter[1] += static_cast<int>(p1.y);
        ++pointNum;
      }
      if (pointNum != 0) {
        opticalPrevCenter[0] /= pointNum;
        opticalPrevCenter[1] /= pointNum;
        opticalCurrCenter[0] /= pointNum;
        opticalCurrCenter[1] /= pointNum;
      }
      // NOTE(review): when no corner was tracked (pointNum == 0) the motion is
      // reported as zero, i.e. the image centre is still pushed. Confirm this
      // "no motion" fallback is what consumers of pts expect.

      // Scale the mean displacement by 3 and suppress small jitter.
      auto deadband = [](int v) { return (v < 10 && v > -10) ? 0 : v; };
      const int meanX = deadband((opticalCurrCenter[0] - opticalPrevCenter[0]) * 3);
      const int meanY = deadband((opticalCurrCenter[1] - opticalPrevCenter[1]) * 3);

      const Point center(PROCESS_WIDTH / 2, PROCESS_HEIGHT / 2);
      const Point tip(center.x + meanX, center.y + meanY);
      circle(resultIm, tip, 5, Scalar(0, 255, 255), 4);
      line(resultIm, center, tip, Scalar(0, 0, 255), 5);
      pts.push_back(tip);
    }
  }
  gray.copyTo(prevGray);
}
